#!/usr/bin/env python3
import os
import math
import ptan
import time
import gymnasium as gym
# TODO: replace the deprecated roboschool with pybullet
# import roboschool
import argparse
import pybullet
from tensorboardX import SummaryWriter

from lib import model, common

import numpy as np
import torch
import torch.optim as optim
import torch.nn.functional as F

# TODO: adapt this script to HalfCheetahBulletEnv (pybullet) instead of roboschool
ENV_ID = "RoboschoolHalfCheetah-v1"
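# A minimal migration sketch for the TODO above (an assumption, not applied here;
# the env ids come from pybullet's classic gym registration and may differ by version):
#   import pybullet_envs  # registers HalfCheetahBulletEnv-v0 and friends
#   ENV_ID = "HalfCheetahBulletEnv-v0"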
GAMMA = 0.99
REWARD_STEPS = 5  # horizon of the n-step return used for the value targets
BATCH_SIZE = 32  # number of transitions collected before each training update
LEARNING_RATE_ACTOR = 1e-5
LEARNING_RATE_CRITIC = 1e-3
ENTROPY_BETA = 1e-3
ENVS_COUNT = 16

TEST_ITERS = 100000  # run a test pass every TEST_ITERS training steps


def test_net(net, env, count=10, device="cpu"):
    '''
    net: actor network to evaluate
    env: environment to run the test episodes in
    count: number of test episodes
    device: device to run the network on

    return: mean episode reward, mean episode step count
    '''
    rewards = 0.0
    steps = 0
    for _ in range(count):
        # reset the environment to start a new episode
        # (gymnasium's reset() returns an (obs, info) tuple)
        obs, _ = env.reset()
        while True:
            obs_v = ptan.agent.float32_preprocessor([obs]).to(device)
            mu_v = net(obs_v)[0]
            action = mu_v.squeeze(dim=0).data.cpu().numpy()
            action = np.clip(action, -1, 1)
            # take the action; gymnasium's step() returns
            # (obs, reward, terminated, truncated, info)
            obs, reward, terminated, truncated, _ = env.step(action)
            rewards += reward
            steps += 1
            if terminated or truncated:
                break
    return rewards / count, steps / count
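# Note: test_net scores the policy deterministically by acting on the Gaussian
# mean mu instead of sampling, the usual way to evaluate a stochastic
# continuous-control policy.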


def calc_logprob(mu_v, logstd_v, actions_v):
    '''
    Compute the log-probability of continuous actions under a Gaussian
    density with mean mu_v and log standard deviation logstd_v.
    '''
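    # log N(a | mu, sigma^2) = -(a - mu)^2 / (2 * sigma^2) - log(sqrt(2 * pi * sigma^2))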
    # logstd is the log of the standard deviation, so the variance is exp(2 * logstd);
    # it is clamped from below to keep the division and the log numerically stable
    var_v = torch.exp(2 * logstd_v).clamp(min=1e-3)
    p1 = - ((mu_v - actions_v) ** 2) / (2 * var_v)
    p2 = - torch.log(torch.sqrt(2 * math.pi * var_v))
    return p1 + p2
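# Hypothetical sanity check: with mu = 0, logstd = 0 and action = 0 this is the
# standard normal log-density at its mode, log(1 / sqrt(2 * pi)) ≈ -0.9189:
#   calc_logprob(torch.zeros(1, 1), torch.zeros(1), torch.zeros(1, 1))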


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--cuda", default=False, action='store_true', help='Enable CUDA')
    parser.add_argument("-n", "--name", required=True, help="Name of the run")
    parser.add_argument("-e", "--env", default=ENV_ID, help="Environment id, default=" + ENV_ID)
    args = parser.parse_args()
    device = torch.device("cuda" if args.cuda else "cpu")

    save_path = os.path.join("saves", "a2c-" + args.name)
    os.makedirs(save_path, exist_ok=True)

    # create the pool of parallel training environments
    envs = [gym.make(args.env) for _ in range(ENVS_COUNT)]
    # create a single environment for periodic evaluation
    test_env = gym.make(args.env)

    # actor network: maps observations to action means (mu)
    net_act = model.ModelActor(envs[0].observation_space.shape[0], envs[0].action_space.shape[0]).to(device)
    # critic network: maps observations to state-value estimates V(s)
    net_crt = model.ModelCritic(envs[0].observation_space.shape[0]).to(device)
    print(net_act)
    print(net_crt)

    writer = SummaryWriter(comment="-a2c_" + args.name)
    agent = model.AgentA2C(net_act, device=device)
    exp_source = ptan.experience.ExperienceSourceFirstLast(envs, agent, GAMMA, steps_count=REWARD_STEPS)
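    # ExperienceSourceFirstLast yields (state, action, reward, last_state) tuples
    # where reward is already the discounted sum over REWARD_STEPS steps and
    # last_state is None when the episode ended within the unroll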

    opt_act = optim.Adam(net_act.parameters(), lr=LEARNING_RATE_ACTOR)
    opt_crt = optim.Adam(net_crt.parameters(), lr=LEARNING_RATE_CRITIC)

    batch = []
    best_reward = None
    with ptan.common.utils.RewardTracker(writer) as tracker:
        with ptan.common.utils.TBMeanTracker(writer, batch_size=100) as tb_tracker:
            for step_idx, exp in enumerate(exp_source):
                rewards_steps = exp_source.pop_rewards_steps()
                if rewards_steps:
                    "记录最近的一次游戏的回报奖励"
                    rewards, steps = zip(*rewards_steps)
                    tb_tracker.track("episode_steps", np.mean(steps), step_idx)
                    tracker.reward(np.mean(rewards), step_idx)

                if step_idx % TEST_ITERS == 0:
                    ts = time.time()
                    # evaluate the current actor network
                    rewards, steps = test_net(net_act, test_env, device=device)
                    print("Test done in %.2f sec, reward %.3f, steps %d" % (
                        time.time() - ts, rewards, steps))
                    # log the test results and save the weights whenever the best reward improves
                    writer.add_scalar("test_reward", rewards, step_idx)
                    writer.add_scalar("test_steps", steps, step_idx)
                    if best_reward is None or best_reward < rewards:
                        if best_reward is not None:
                            print("Best reward updated: %.3f -> %.3f" % (best_reward, rewards))
                            name = "best_%+.3f_%d.dat" % (rewards, step_idx)
                            fname = os.path.join(save_path, name)
                            torch.save(net_act.state_dict(), fname)
                        best_reward = rewards

                batch.append(exp)
                # keep collecting transitions until the buffer reaches BATCH_SIZE
                if len(batch) < BATCH_SIZE:
                    continue

                # unpack the collected transitions into training tensors
                states_v, actions_v, vals_ref_v = \
                    common.unpack_batch_a2c(batch, net_crt, last_val_gamma=GAMMA ** REWARD_STEPS, device=device)
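                # vals_ref_v are the n-step value targets:
                # r_0 + gamma*r_1 + ... + gamma^(N-1)*r_{N-1} + gamma^N * V(last_state)
                # with N = REWARD_STEPS (the bootstrap term is dropped on episode end)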
                # clear the buffer for the next batch
                batch.clear()

                opt_crt.zero_grad()
                # critic's estimate of the state values V(s)
                value_v = net_crt(states_v)
                # first loss: the predicted values should match the n-step targets
                loss_value_v = F.mse_loss(value_v.squeeze(-1), vals_ref_v)
                loss_value_v.backward()
                opt_crt.step()

                opt_act.zero_grad()
                # action means predicted by the actor for the batch states
                mu_v = net_act(states_v)
                # advantage: n-step target minus the critic's baseline, measuring
                # how much better the taken action was than the average action
                adv_v = vals_ref_v.unsqueeze(dim=-1) - value_v.detach()
                # policy loss: minimize the negative advantage-weighted log-probability
                log_prob_v = adv_v * calc_logprob(mu_v, net_act.logstd, actions_v)
                loss_policy_v = -log_prob_v.mean()
                # entropy bonus to encourage exploration; the differential entropy of a
                # Gaussian is (log(2*pi*sigma^2) + 1) / 2, with sigma^2 = exp(2 * logstd)
                entropy_loss_v = ENTROPY_BETA * (-(torch.log(2*math.pi*torch.exp(2*net_act.logstd)) + 1)/2).mean()
                # the actor's total loss combines only the policy and entropy terms;
                # the critic is a separate network with its own optimizer and was
                # already updated above, so no value loss is included here
                loss_v = loss_policy_v + entropy_loss_v
                loss_v.backward()
                opt_act.step()

                # track training statistics
                tb_tracker.track("advantage", adv_v, step_idx)
                tb_tracker.track("values", value_v, step_idx)
                tb_tracker.track("batch_rewards", vals_ref_v, step_idx)
                tb_tracker.track("loss_entropy", entropy_loss_v, step_idx)
                tb_tracker.track("loss_policy", loss_policy_v, step_idx)
                tb_tracker.track("loss_value", loss_value_v, step_idx)
                tb_tracker.track("loss_total", loss_v, step_idx)

